import json
from tqdm import tqdm
import time
import argparse

import argparse
import os
import copy
from easydict import EasyDict as edict
import numpy as np
import json
import torch
from PIL import Image, ImageDraw, ImageFont
from collections import defaultdict
import shutil

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
from segment_anything import (
    sam_model_registry,
    sam_hq_model_registry,
    SamPredictor
)
import cv2
import numpy as np
import matplotlib.pyplot as plt


def load_image(image_path):
    """Read an image from disk and preprocess it for the grounding model.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(image_pil, image)`` tuple: the RGB ``PIL.Image`` as read from
        disk, and the output of the resize / to-tensor / normalize pipeline
        applied to it (annotated upstream as shape ``3, h, w``).
    """
    pil_rgb = Image.open(image_path).convert("RGB")

    # Per-channel statistics handed to T.Normalize.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    pipeline = T.Compose(steps)

    # The transforms take an (image, target) pair; no target is needed here.
    tensor, _ = pipeline(pil_rgb, None)
    return pil_rgb, tensor


def load_model(model_config_path, model_checkpoint_path, device):
    """Build the grounding model from a config file and load its weights.

    Args:
        model_config_path: Path of the config file read by ``SLConfig``.
        model_checkpoint_path: Path of a checkpoint whose ``"model"`` entry
            holds the state dict.
        device: Value stored on the config as ``device`` before building.

    Returns:
        The model, switched to eval mode.
    """
    config = SLConfig.fromfile(model_config_path)
    config.device = device
    net = build_model(config)

    # Weights are deserialized onto the CPU regardless of ``device``.
    state_dict = torch.load(model_checkpoint_path, map_location="cpu")["model"]
    # strict=False: mismatched keys are reported by the print, not raised.
    print(net.load_state_dict(clean_state_dict(state_dict), strict=False))

    net.eval()
    return net


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    """Run the grounding model on one image and keep the confident detections.

    Args:
        model: Grounding model. It is moved with ``.to(device)``, called as
            ``model(image[None], captions=[caption])`` and must expose a
            ``tokenizer`` attribute.
        image: Image tensor for a single image; a batch axis is added here.
        caption: Text prompt. It is lower-cased, stripped, and given a
            trailing ``"."`` if it does not already end with one.
        box_threshold: A query is kept when its highest sigmoid logit is
            strictly greater than this value.
        text_threshold: Per-token threshold applied to a kept query's logits
            to build the mask passed to ``get_phrases_from_posmap``.
        with_logits: When true, append ``"(<score>)"`` to every phrase, where
            the score is the first four characters of the query's top logit.
        device: Device on which the forward pass runs.

    Returns:
        ``(boxes_filt, pred_phrases)``: the ``pred_boxes`` rows of the kept
        queries (unscaled, exactly as the model produced them) and one phrase
        string per kept query.
    """
    # Normalize the prompt: lower-case, trimmed, terminated by a period.
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."

    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    # Index 0 selects the single image of the batch built above.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Boolean-mask indexing returns fresh tensors, so the model outputs are
    # never aliased by the filtered results and no explicit clone is needed.
    keep = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[keep]  # num_filt, 256
    boxes_filt = boxes[keep]  # num_filt, 4

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            # Score is the string form of the top logit cut to 4 characters.
            phrase = phrase + f"({str(logit.max().item())[:4]})"
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases

def show_mask(mask, ax, random_color=False):
    """Draw a mask on ``ax`` as a semi-transparent colored overlay.

    Args:
        mask: Array whose last two dimensions are the mask height and width.
        ax: Object exposing ``imshow`` (e.g. a matplotlib axes).
        random_color: Use a random RGB color instead of the fixed default.
    """
    # RGB part of the overlay color; the alpha channel is always 0.6.
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.concatenate([rgb, np.array([0.6])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to get an RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)


def show_box(box, ax, label):
    """Draw a green box outline with a text label at its first corner.

    Args:
        box: Indexable ``(x0, y0, x1, y1)`` corner coordinates.
        ax: Axes-like object exposing ``add_patch`` and ``text``.
        label: Text placed at ``(x0, y0)``.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    # Fully transparent face so only the outline is visible.
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)
    ax.text(left, top, label)


def _split_label_score(label):
    """Split a ``"name(score)"`` label into ``(name, score)``; score defaults to 1.0."""
    name, sep, tail = label.rpartition('(')
    if sep and tail.endswith(')'):
        try:
            return name, float(tail[:-1])
        except ValueError:
            # The parenthesised text is part of the name, not a score.
            pass
    return label, 1.0


def save_mask_data(output_dir, mask_list, box_list, label_list):
    """Write a label-index mask image and a JSON description of the masks.

    Both files use fixed names (``mask.jpg`` and ``mask.json``) inside
    ``output_dir``, so each call overwrites the previous call's files.

    Args:
        output_dir: Existing directory that receives the two files.
        mask_list: Tensor of masks; its last two dimensions give the image
            size and index 0 of each mask is used (assumes a singleton
            channel axis per mask — confirm against the caller).
        box_list: Iterable of CPU tensors, one box per label.
        label_list: Labels, either ``"name"`` or ``"name(score)"``. A label
            without a parsable trailing score gets the score ``1.0``.
    """
    value = 0  # 0 for background

    # Pixel value i + 1 marks the i-th mask; later masks overwrite earlier ones.
    mask_img = torch.zeros(mask_list.shape[-2:])
    for idx, mask in enumerate(mask_list):
        # Elementwise comparison yields the boolean index used for assignment.
        mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1  # noqa: E712

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(mask_img.numpy())
        plt.axis('off')
        plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)
    finally:
        # Always release the figure; pyplot otherwise keeps it alive, which
        # accumulates memory when this runs once per image.
        plt.close(fig)

    json_data = [{
        'value': value,
        'label': 'background'
    }]
    for entry_value, (label, box) in enumerate(zip(label_list, box_list), start=value + 1):
        name, score = _split_label_score(label)
        json_data.append({
            'value': entry_value,
            'label': name,
            'logit': score,
            'box': box.numpy().tolist(),
        })
    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
        json.dump(json_data, f)


def generate_mask(predictor, args, model=None):
    """Segment the requested boxes of one image with SAM and save the results.

    Boxes come from two sources: ``args.boxes`` (pixel ``x0, y0, x1, y1``
    lists) and, when ``model`` is given, the grounding model's detections for
    ``args.text_prompt``. The masks are saved as ``<name>_mask.npy``, an
    overlay as ``<name>_mask.jpg``, and ``save_mask_data`` is called for the
    same output directory.

    Args:
        predictor: SAM predictor exposing ``set_image``,
            ``transform.apply_boxes_torch`` and ``predict_torch``.
        args: Mapping with attribute access (``.get`` is also used). Reads
            ``input_image``, ``text_prompt``, ``boxes``, ``output_dir``,
            optional ``output_filename`` (defaults to the image's base name
            without extension) and, when ``model`` is given,
            ``box_threshold`` / ``text_threshold``. ``sam_version`` and
            ``device`` are overwritten on ``args`` with ``"vit_h"`` and
            ``"cuda"``.
        model: Optional grounding model; ``None`` uses only ``args.boxes``.

    Returns:
        ``True`` when the outputs for this request exist on disk (freshly
        written or already present from an earlier run), ``False`` when no
        box was detected or mask prediction failed.

    Raises:
        OSError: If OpenCV cannot read ``args.input_image``.
    """
    args.sam_version = "vit_h"
    args.device = "cuda"

    output_filename = args.get("output_filename", os.path.splitext(os.path.basename(args.input_image))[0])

    # cfg
    image_path = args.input_image
    text_prompt = args.text_prompt
    boxes = args.boxes
    output_dir = args.output_dir
    device = args.device

    # make dir
    os.makedirs(output_dir, exist_ok=True)

    output_image_filename = os.path.join(output_dir, f"{output_filename}_mask.jpg")
    output_mask_filename = os.path.join(output_dir, f'{output_filename}_mask.npy')
    if os.path.exists(output_image_filename) and os.path.exists(output_mask_filename):
        print(f"Skip {output_image_filename}")
        # Outputs are already on disk, so report success to the caller.
        return True

    if model is not None:
        # The resized/normalized tensor is only needed by the grounding
        # model, so it is built only on this path.
        image_pil, image = load_image(image_path)

        # run grounding dino model
        boxes_filt, pred_phrases = get_grounding_output(
            model, image, text_prompt, args.box_threshold, args.text_threshold, device=device
        )
        if boxes_filt.size(0) == 0:
            # NOTE(review): this also discards any boxes supplied in
            # args.boxes — confirm that is intended for combined modes.
            print(f"Failed to extract bbs from {image_path},\n\ttext_prompt {text_prompt}")
            return False

        # Scale the model's (cx, cy, w, h) boxes by the image size and
        # convert them to pixel (x0, y0, x1, y1).
        W, H = image_pil.size
        boxes_filt = boxes_filt * torch.tensor([W, H, W, H], dtype=boxes_filt.dtype)
        boxes_filt[:, :2] -= boxes_filt[:, 2:] / 2
        boxes_filt[:, 2:] += boxes_filt[:, :2]

        if len(boxes) != 0:
            # Match dtypes explicitly so integer boxes concatenate cleanly
            # with the float detections.
            extra_boxes = torch.tensor(boxes, dtype=boxes_filt.dtype)
            boxes_filt = torch.cat([boxes_filt, extra_boxes], dim=0)
            pred_phrases = pred_phrases + [text_prompt] * len(boxes)
    else:
        pred_phrases = [text_prompt] * len(boxes)
        boxes_filt = torch.tensor(boxes)

    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"Could not read image {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

    try:
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
    except Exception as e:
        # Best effort: report the failure and let the caller move on.
        print(e)
        print(f"Failed to extract mask from {image_path},\n\ttext_prompt {text_prompt}")
        return False
    np.save(output_mask_filename, masks.cpu().numpy())

    # draw output image
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
        for box, label in zip(boxes_filt, pred_phrases):
            show_box(box.numpy(), plt.gca(), label)

        plt.axis('off')
        plt.savefig(
            output_image_filename,
            bbox_inches="tight", dpi=300, pad_inches=0.0
        )
    finally:
        plt.close(fig)

    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
    return True


def main(
        annotation_file='../data/gqa/val_balanced_gqa_coco_captions_region_captions_scene_graphs.jsonl',
        image_folder='../data/vg/',
        output_folder='../data/vg_samples/remove_anything/gsam_masks',
        debug=False,
        mode="box"):
    """Generate masks for the scene-graph objects of every annotated image.

    For each annotation record, objects that have at least one relation and
    one attribute are grouped by name (``"background"`` is skipped), and
    ``generate_mask`` is called once per name. The source image is copied
    into the output folder when at least one call reports success.

    Args:
        annotation_file: JSONL file; each record provides ``vg_id`` and a
            ``scene_graph`` mapping object ids to dicts with ``name``,
            ``relations``, ``attributes`` and ``x``/``y``/``w``/``h``.
        image_folder: Directory containing ``<vg_id>.jpg`` images.
        output_folder: Output directory prefix; ``"_" + mode`` is appended.
        debug: When true, only the first 100 records are processed.
        mode: Underscore-separated flags. ``"box"`` passes the scene-graph
            boxes to SAM; ``"text"`` loads the grounding model and uses the
            object name as text prompt.

    Raises:
        FileNotFoundError: If the image of a record with eligible objects
            is missing from ``image_folder``.
    """
    device = "cuda"
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    use_sam_hq = False
    sam_hq_checkpoint = None
    sam_version = "vit_h"
    output_folder = output_folder + "_" + mode

    # Parse the mode flags once instead of on every loop iteration.
    mode_flags = mode.split("_")
    use_boxes = "box" in mode_flags
    if "text" in mode_flags:
        config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
        grounded_checkpoint = "groundingdino_swint_ogc.pth"
        model = load_model(config_file, grounded_checkpoint, device=device)
    else:
        model = None

    # initialize SAM
    if use_sam_hq:
        predictor = SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
    else:
        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))

    # Close the file deterministically and tolerate blank lines.
    with open(annotation_file, "r") as f:
        data = [json.loads(line) for line in f if line.strip()]
    if debug:
        data = data[:100]

    for d in tqdm(data):
        file_id = d["vg_id"]
        scene_graph = d["scene_graph"]

        # get the list of objects that are with the same name
        name2obj = defaultdict(list)
        for obj_id, item in scene_graph.items():
            if not item["relations"] or not item["attributes"]:
                continue
            name = item["name"]
            if name == "background":
                continue
            name2obj[name].append(obj_id)
        if not name2obj:
            continue

        image_path = os.path.join(image_folder, f'{file_id}.jpg')
        # Explicit check instead of ``assert``, which is stripped under -O.
        if not os.path.exists(image_path):
            raise FileNotFoundError(image_path)

        save_raw = False
        for name, obj_ids in tqdm(name2obj.items(), desc=f"Processing {file_id}"):
            obj_ids = sorted(obj_ids)
            mask_args = edict()
            mask_args.input_image = image_path
            mask_args.output_dir = output_folder
            mask_args.output_filename = f'{file_id}.{"-".join(obj_ids)}'
            mask_args.text_prompt = name
            mask_args.boxes = []
            if use_boxes:
                for obj_id in obj_ids:
                    item = scene_graph[obj_id]
                    # Convert (x, y, w, h) to corner form (x0, y0, x1, y1).
                    mask_args.boxes.append(
                        [item["x"], item["y"], item["x"] + item["w"], item["y"] + item["h"]]
                    )
            if model is not None:
                mask_args.box_threshold = 0.4
                mask_args.text_threshold = 0.4
            if generate_mask(predictor, mask_args, model=model):
                save_raw = True
        if save_raw:
            shutil.copyfile(image_path, os.path.join(output_folder, f'{file_id}.jpg'))

if __name__ == '__main__':
    # Command-line entry point: python-fire maps main()'s keyword arguments
    # (annotation_file, image_folder, output_folder, debug, mode) to flags.
    import fire
    fire.Fire(main)
